package hex.drf;

import hex.ConfusionMatrix;
import hex.VarImp;
import hex.drf.TreeMeasuresCollector.TreeMeasures;
import hex.drf.TreeMeasuresCollector.TreeSSE;
import hex.drf.TreeMeasuresCollector.TreeVotes;
import hex.gbm.DHistogram;
import hex.gbm.DTree;
import hex.gbm.DTree.DecidedNode;
import hex.gbm.DTree.LeafNode;
import hex.gbm.DTree.TreeModel.CompressedTree;
import hex.gbm.DTree.TreeModel.TreeStats;
import hex.gbm.DTree.UndecidedNode;
import hex.gbm.SharedTreeModelBuilder;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.api.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.*;
import water.util.Log.Tag.Sys;

import java.util.Arrays;
import java.util.Random;

import static hex.drf.TreeMeasuresCollector.asSSE;
import static hex.drf.TreeMeasuresCollector.asVotes;
import static water.util.Utils.div;
import static water.util.Utils.sum;

// Random Forest Trees
/**
 * Distributed Random Forest (DRF) model builder.
 *
 * <p>Each iteration builds K trees (K = number of response classes for
 * classification, a single tree for regression). Every tree is trained on a
 * random row sample of rate {@code sample_rate} and, at each split, scores only
 * {@code _mtry} randomly selected columns. Rows left out of the sample are
 * marked out-of-bag (OOB) and are used for on-the-fly scoring.
 */
public class DRF extends SharedTreeModelBuilder<DRF.DRFModel> {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  // Enable this for a deterministic version of DRF: when no seed is given (seed == -1),
  // the same fixed seed is used for every execution (see init()).
  // NOTE: it would be preferable to read this flag from system properties.
  static final boolean DEBUG_DETERMINISTIC = false;

  // NOTE: when -1, init() uses sqrt(#cols) for classification but #cols/3 for regression.
  @API(help = "Columns to randomly select at each level, or -1 for sqrt(#cols)", filter = Default.class, lmin=-1, lmax=100000)
  int mtries = -1;

  @API(help = "Sample rate, from 0. to 1.0", filter = Default.class, dmin=0, dmax=1, importance=ParamImportance.SECONDARY)
  float sample_rate = 0.6666667f;

  @API(help = "Seed for the random number generator (autogenerated)", filter = Default.class)
  long seed = -1; // To follow R-semantics, each call of RF should provide different seed. -1 means seed autogeneration

  @API(help = "Check non-contiguous group splits for categorical predictors", filter = Default.class, hide = true)
  boolean do_grpsplit = true;

  @API(help="Run on one node only; no network overhead but fewer cpus used.  Suitable for small datasets.", filter=myClassFilter.class, importance=ParamImportance.SECONDARY)
  public boolean build_tree_one_node = false;
  class myClassFilter extends DRFCopyDataBoolean { myClassFilter() { super("source"); } }

  @API(help = "Computed number of split features", importance=ParamImportance.EXPERT)
  protected int _mtry; // FIXME remove and replace by mtries

  @API(help = "Autogenerated seed", importance=ParamImportance.EXPERT)
  protected long _seed; // FIXME remove and replace by seed

  // Fixed seed generator for DRF
  private static final Random _seedGenerator = Utils.getDeterRNG(0xd280524ad7fe0602L);

  // --- Private data handled only on master node
  // Classification or Regression:
  // Tree votes/SSE of individual trees on OOB rows
  private transient TreeMeasures _treeMeasuresOnOOB;
  // Tree votes/SSE per individual features on permuted OOB rows
  private transient TreeMeasures[/*features*/] _treeMeasuresOnSOOB;
  // Variable importance based on tree split decisions (accumulated split improvement per column)
  private transient float[/*nfeatures*/] _improvPerVar;

  /** DRF model holding serialized tree and implementing logic for scoring a row */
  public static class DRFModel extends DTree.TreeModel {
    static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
    static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

    @API(help = "Model parameters", json = true)
    private final DRF parameters;    // This is used purely for printing values out.
    @Override public final DRF get_params() { return parameters; }
    @Override public final Request2 job() { return get_params(); }

    @API(help = "Number of columns picked at each split") final int mtries;
    @API(help = "Sample rate") final float sample_rate;
    @API(help = "Seed") final long seed;

    // Params that do not affect model quality:
    //
    /** Creates a fresh model; stores a sanitized clone of the builder parameters. */
    public DRFModel(DRF params, Key key, Key dataKey, Key testKey, String names[], String domains[][], String[] cmDomain, int ntrees, int max_depth, int min_rows, int nbins, int mtries, float sample_rate, long seed, int num_folds, float[] priorClassDist, float[] classDist) {
      super(key,dataKey,testKey,names,domains,cmDomain,ntrees, max_depth, min_rows, nbins, num_folds, priorClassDist, classDist);
      this.parameters = Job.hygiene((DRF) params.clone());
      this.mtries = mtries;
      this.sample_rate = sample_rate;
      this.seed = seed;
    }
    /** Copy of {@code prior} extended with new trees and tree statistics. */
    private DRFModel(DRFModel prior, DTree[] trees, TreeStats tstats) {
      super(prior, trees, tstats);
      this.parameters = prior.parameters;
      this.mtries = prior.mtries;
      this.sample_rate = prior.sample_rate;
      this.seed = prior.seed;
    }
    /** Copy of {@code prior} with updated error, confusion matrix, variable importance and AUC. */
    private DRFModel(DRFModel prior, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
      super(prior, err, cm, varimp, validAUC);
      this.parameters = prior.parameters;
      this.mtries = prior.mtries;
      this.sample_rate = prior.sample_rate;
      this.seed = prior.seed;
    }
    /** Copy of {@code prior} taking trees and scoring history from another (checkpointed) model. */
    private DRFModel(DRFModel prior, Key[][] treeKeys, double[] errs, ConfusionMatrix[] cms, TreeStats tstats, VarImp varimp, AUCData validAUC) {
      super(prior, treeKeys, errs, cms, tstats, varimp, validAUC);
      this.parameters = prior.parameters;
      this.mtries = prior.mtries;
      this.sample_rate = prior.sample_rate;
      this.seed = prior.seed;
    }

    @Override protected TreeModelType getTreeModelType() { return TreeModelType.DRF; }

    /**
     * Scores one row. The parent sums the per-tree outputs; this method then
     * turns the sums into a prediction:
     * <ul>
     *   <li>single output (regression): average over the number of trees;</li>
     *   <li>otherwise (classification): normalize the class votes so they sum to 1
     *       and store the predicted label in {@code p[0]}.</li>
     * </ul>
     */
    @Override protected float[] score0(double data[], float preds[]) {
      float[] p = super.score0(data, preds);
      int ntrees = ntrees();
      if (p.length==1) { if (ntrees>0) div(p, ntrees); } // regression - compute avg over all trees
      else { // classification
        float s = sum(p);
        if (s>0) div(p, s); // unify over all classes
        p[0] = ModelUtils.getPrediction(p, data);
      }
      return p;
    }
    /** Appends the DRF-specific parameters (and an OOB warning when applicable) to the HTML description. */
    @Override protected void generateModelDescription(StringBuilder sb) {
      DocGen.HTML.paragraph(sb,"mtries: "+mtries+", Sample rate: "+sample_rate+", Seed: "+seed);
      if (testKey==null && sample_rate==1f) {
        sb.append("<div class=\"alert alert-danger\">There are no out-of-bag data to compute out-of-bag error estimate, since sampling rate is 1!</div>");
      }
    }
    /**
     * Emits the generated-Java code that post-processes summed tree outputs:
     * classifiers normalize {@code preds[1..]} by their sum, regression divides
     * {@code preds[1]} by {@code NTREES}.
     */
    @Override protected void toJavaUnifyPreds(SB bodySb) {
      if (isClassifier()) {
        bodySb.i().p("float sum = 0;").nl();
        bodySb.i().p("for(int i=1; i<preds.length; i++) sum += preds[i];").nl();
        bodySb.i().p("if (sum>0) for(int i=1; i<preds.length; i++) preds[i] /= sum;").nl();
      } else bodySb.i().p("preds[1] = preds[1]/NTREES;").nl();
    }
    /** Stores a copy of this model carrying the cross-validation results, overwriting this model's key. */
    @Override protected void setCrossValidationError(ValidatedJob job, double cv_error, water.api.ConfusionMatrix cm, AUCData auc, HitRatio hr) {
      DRFModel drfm = ((DRF)job).makeModel(this, cv_error, cm.cm == null ? null : new ConfusionMatrix(cm.cm, cms[0].nclasses()), this.varimp, auc);
      drfm._have_cv_results = true;
      DKV.put(this._key, drfm); //overwrite this model
    }
  }
  /** Scores {@code fr} with the model stored under this job's destination key. */
  public Frame score( Frame fr ) { return ((DRFModel)UKV.get(dest())).score(fr);  }

  @Override protected Log.Tag.Sys logTag() { return Sys.DRF__; }
  @Override protected DRFModel makeModel(Key outputKey, Key dataKey, Key testKey, int ntrees, String[] names, String[][] domains, String[] cmDomain, float[] priorClassDist, float[] classDist) {
    return new DRFModel(this,outputKey,dataKey,validation==null?null:testKey,names,domains,cmDomain,ntrees, max_depth, min_rows, nbins, mtries, sample_rate, _seed, n_folds, priorClassDist, classDist);
  }

  @Override protected DRFModel makeModel( DRFModel model, double err, ConfusionMatrix cm, VarImp varimp, AUCData validAUC) {
    return new DRFModel(model, err, cm, varimp, validAUC);
  }
  @Override protected DRFModel makeModel( DRFModel model, DTree ktrees[], TreeStats tstats) {
    return new DRFModel(model, ktrees, tstats);
  }
  /**
   * Combines {@code model} with the trees and scoring history of {@code checkpoint}.
   * Only the overwriting mode is supported; cloning checkpoint trees is not implemented.
   */
  @Override protected DRFModel updateModel(DRFModel model, DRFModel checkpoint, boolean overwriteCheckpoint) {
    // Do not forget to clone trees in case that we are not going to overwrite checkpoint
    Key[][] treeKeys = null;
    if (!overwriteCheckpoint) throw H2O.unimpl("Cloning of model trees is not implemented yet!");
    else treeKeys = checkpoint.treeKeys;
    return new DRFModel(model, treeKeys, checkpoint.errs, checkpoint.cms, checkpoint.treeStats, checkpoint.varimp, checkpoint.validAUC);
  }
  public DRF() { description = "Distributed RF"; ntrees = 50; max_depth = 20; min_rows = 1; }

  /** Return the query link to this page */
  public static String link(Key k, String content) {
    RString rs = new RString("<a href='/2/DRF.query?source=%$key'>%content</a>");
    rs.replace("key", k.toString());
    rs.replace("content", content);
    return rs.toString();
  }

  // ==========================================================================

  /** Compute a DRF tree.
   *
   * Start by splitting all the data according to some criteria (minimize
   * variance at the leaves).  Record on each row which split it goes to, and
   * assign a split number to it (for next pass).  On *this* pass, use the
   * split-number to build a per-split histogram, with a per-histogram-bucket
   * variance.
   *
   * Builds the model, runs n-fold cross-validation when {@code n_folds > 0}, and
   * always (even on failure) removes the job and copies the final job state into
   * the parameters stored inside the model at {@code dest()}. */
  @Override protected void execImpl() {
    try {
      logStart();
      buildModel(seed);
      if (n_folds > 0) CrossValUtils.crossValidate(this);
    } finally {
      remove();                   // Remove Job
      // Ugly hack updating job state carried as parameters inside a model
      state = UKV.<Job>get(self()).state;
      new TAtomic<DRFModel>() {
        @Override
        public DRFModel atomic(DRFModel m) {
          if (m != null) m.get_params().state = state;
          return m;
        }
      }.invoke(dest());
    }
  }

  @Override protected Response redirect() {
    return DRFProgressPage.redirect(this, self(), dest());
  }

  /**
   * Validates the parameters and derives {@code _mtry} and {@code _seed}.
   *
   * @throws IllegalArgumentException if the computed mtry is outside [1, #cols]
   *         or {@code sample_rate} is outside (0, 1]
   */
  @SuppressWarnings("unused")
  @Override protected void init() {
    super.init();
    // Initialize local variables
    _mtry = (mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
        ( classification ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1))  : mtries;
    if (!(1 <= _mtry && _mtry <= _ncols)) throw new IllegalArgumentException("Computed mtry should be in interval <1,#cols> but it is " + _mtry);
    if (!(0.0 < sample_rate && sample_rate <= 1.0)) throw new IllegalArgumentException("Sample rate should be interval (0,1> but it is " + sample_rate);
    if (DEBUG_DETERMINISTIC && seed == -1) _seed = 0x1321e74a0192470cL; // fixed version of seed
    else if (seed == -1) _seed = _seedGenerator.nextLong(); else _seed = seed;
    // With 100% sampling every row is in-bag, so OOB error can only be replaced by
    // a validation frame. Warn only when that validation frame is missing.
    if (sample_rate==1f && validation==null)
      Log.warn(Sys.DRF__, "Sample rate is 100% and no validation dataset is specified. There are no OOB data to compute out-of-bag error estimation!");
    if (!classification && do_grpsplit) {
      Log.info(Sys.DRF__, "Group splitting not supported for DRF regression. Forcing group splitting to false.");
      do_grpsplit = false;
    }
  }

  @Override protected void initAlgo(DRFModel initialModel) {
    // Initialize TreeVotes for classification, MSE arrays for regression
    if (importance) initTreeMeasurements();
  }
  /**
   * Prepares the working frame: appends the zero-initialized {@code OUT_BAG_TREES}
   * column, fills the work columns, and, when resuming from a checkpoint without a
   * validation frame, replays OOB scoring of the checkpointed trees.
   */
  @Override protected void initWorkFrame(DRFModel initialModel, Frame fr) {
    // Append number of trees participating in on-the-fly scoring
    fr.add("OUT_BAG_TREES", response.makeZero());
    // Prepare working columns
    new SetWrkTask().doAll(fr);
    // If there was a check point recompute tree_<_> and oob columns based on predictions from previous trees
    // but only if OOB validation is requested.
    if (validation==null && checkpoint!=null) {
      Timer t = new Timer();
      // Compute oob votes for each output level
      new OOBScorer(_ncols, _nclass, sample_rate, initialModel.treeKeys).doAll(fr);
      Log.info(logTag(), "Reconstructing oob stats from checkpointed model took " + t);
    }
  }

  /**
   * Main training loop: builds {@code ntrees} iterations of K trees.
   *
   * <p>Before each iteration the model is scored (except the first iteration when
   * resuming from a checkpoint), and a final scoring pass runs after the loop
   * unless the job was cancelled.
   */
  @Override protected DRFModel buildModel( DRFModel model, final Frame fr, String names[], String domains[][], final Timer t_build ) {
    // The RNG used to pick split columns
    Random rand = createRNG(_seed);
    // To be deterministic get random numbers for previous trees and
    // put random generator to the same state
    for (int i=0; i<_ntreesFromCheckpoint; i++) rand.nextLong();

    int tid;
    DTree[] ktrees = null;
    // Prepare tree statistics
    TreeStats tstats = model.treeStats!=null ? model.treeStats : new TreeStats();
    // Build trees until we hit the limit
    for( tid=0; tid<ntrees; tid++) { // Building tid-tree
      if (tid!=0 || checkpoint==null) { // do not make initial scoring if model already exist
        model = doScoring(model, fr, ktrees, tid, tstats, tid==0, !hasValidation(), build_tree_one_node);
      }
      // At each iteration build K trees (K = nclass = response column domain size)

      // TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
      // Idea: launch more DRF at once.
      Timer kb_timer = new Timer();
      ktrees = buildNextKTrees(fr,_mtry,sample_rate,rand,tid);
      Log.info(logTag(), (tid+1) + ". tree was built " + kb_timer.toString());
      if( !Job.isRunning(self()) ) break; // If canceled during building, do not bulkscore

      // Check latest predictions
      tstats.updateBy(ktrees);
    }
    if( Job.isRunning(self()) ) { // do not perform final scoring and finish
      model = doScoring(model, fr, ktrees, tid, tstats, true, !hasValidation(), build_tree_one_node);
      // Make sure that we did not miss any votes
//      assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0/*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!";
    }

    return model;
  }

  /**
   * Allocates the variable-importance accumulators: the per-column split
   * improvement array, plus TreeVotes (classification) or TreeSSE (regression)
   * holders for OOB and per-feature shuffled-OOB measurements.
   */
  private void initTreeMeasurements() {
    assert importance : "Tree votes should be initialized only if variable importance is requested!";
    _improvPerVar = new float[_ncols];
    // Preallocate tree votes
    if (classification) {
      _treeMeasuresOnOOB  = new TreeVotes(ntrees);
      _treeMeasuresOnSOOB = new TreeVotes[_ncols];
      for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeVotes(ntrees);
    } else {
      _treeMeasuresOnOOB  = new TreeSSE(ntrees);
      _treeMeasuresOnSOOB = new TreeSSE[_ncols];
      for (int i=0; i<_ncols; i++) _treeMeasuresOnSOOB[i] = new TreeSSE(ntrees);
    }
  }

//  /** On-the-fly version for varimp. After generation a new tree, its tree votes are collected on shuffled
//   * OOB rows and variable importance is recomputed.
//   * <p>
//   * The <a href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
//   * <cite>
//   * "In every tree grown in the forest, put down the oob cases and count the number of votes cast for the correct class.
//   * Now randomly permute the values of variable m in the oob cases and put these cases down the tree.
//   * Subtract the number of votes for the correct class in the variable-m-permuted oob data from the number of votes
//   * for the correct class in the untouched oob data.
//   * The average of this number over all trees in the forest is the raw importance score for variable m."
//   * </cite>
//   * </p>
//   * */
//  @Override
//  protected VarImp doVarImpCalc(final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
//    // Check if we have already serialized 'ktrees'-trees in the model
//    assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid="+tid;
//    assert _treeMeasuresOnOOB.npredictors()-1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
//    // Compute tree votes over shuffled data
//    final CompressedTree[/*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys
//    final int nclasses = model.nclasses();
//    Futures fs = new Futures();
//    for (int var=0; var<_ncols; var++) {
//      final int variable = var;
//      H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() {
//        @Override public void compute2() {
//          // Compute this tree votes over all data over given variable
//          TreeVotes cd = TreeMeasuresCollector.collectVotes(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
//          assert cd.npredictors() == 1;
//          asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
//          tryComplete();
//        }
//      } : /* regression */ new H2OCountedCompleter() {
//        @Override public void compute2() {
//          // Compute this tree votes over all data over given variable
//          TreeSSE cd = TreeMeasuresCollector.collectSSE(theTree, nclasses, fTrain, _ncols, sample_rate, variable);
//          assert cd.npredictors() == 1;
//          asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
//          tryComplete();
//        }
//      };
//      fs.add(task4var);
//      H2O.submitTask(task4var); // Fork computation
//    }
//    fs.blockForPending(); // Wait for results
//    // Compute varimp for individual features (_ncols)
//    final float[] varimp   = new float[_ncols]; // output variable importance
//    final float[] varimpSD = new float[_ncols]; // output variable importance sd
//    for (int var=0; var<_ncols; var++) {
//      double[/*2*/] imp = classification ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB)) :  asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB));
//      varimp  [var] = (float) imp[0];
//      varimpSD[var] = (float) imp[1];
//    }
//    return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees());
//  }

  /** Compute relative variable importance for RF model.
   *
   *  See (45), (35) formulas in Friedman: Greedy Function Approximation: A Gradient boosting machine.
   *  Algo used here can be used for computation individual importance of features per output class.
   *
   *  The split improvements of the newest K trees are added to the running
   *  per-column totals; the totals are then averaged over all trees
   *  ({@code ntrees * nclasses}). When {@code scale} is set, the result is divided by
   *  its maximum so that the most important variable has importance 1 (no scaling
   *  is applied when the maximum is 0, which would otherwise produce NaN). */
  @Override protected VarImp doVarImpCalc(DRFModel model, DTree[] ktrees, int tid, Frame validationFrame, boolean scale) {
    assert model.ntrees()-1-_ntreesFromCheckpoint == tid : "varimp computation expect model with already serialized trees: tid="+tid;
    // Iterates over k-tree
    for (DTree t : ktrees) { // Iterate over trees
      if (t!=null) {
        for (int n = 0; n< t.len()-t.leaves; n++)
          if (t.node(n) instanceof DecidedNode) { // it is split node
            DTree.Split split = t.decided(n)._split;
            if (split.col()!=-1) // Skip impossible splits ~ leafs
              _improvPerVar[split.col()] += split.improvement(); // least squares improvement
          }
      }
    }
    // Compute variable importance for all trees in model
    float[] varimp   = new float[model.nfeatures()];

    int   ntreesTotal = model.ntrees() * model.nclasses();
    int   maxVar = 0;
    for (int var=0; var<_improvPerVar.length; var++) {
      varimp[var] = _improvPerVar[var] / ntreesTotal;
      if (varimp[var] > varimp[maxVar]) maxVar = var;
    }
    // Normalize so that the most important variable has importance 1 (range 0..1).
    // Skip when the maximum is 0 (e.g. only trivial trees were built): 0/0 would yield NaN.
    if (scale) {
      float maxVal = varimp[maxVar];
      if (maxVal != 0f)
        for (int var=0; var<varimp.length; var++) varimp[var] /= maxVal;
    }

    return new VarImp.VarImpRI(varimp);
  }

  @Override public boolean supportsBagging() { return true; }

  /** Fill work columns:
   *   - classification: set 1 in the corresponding wrk col according to row response
   *   - regression:     copy response into work column (there is only 1 work column)
   *  Rows with an NA response are left untouched. */

  private class SetWrkTask extends MRTask2<SetWrkTask> {
    @Override public void map( Chunk chks[] ) {
      Chunk cy = chk_resp(chks);
      for( int i=0; i<cy._len; i++ ) {
        if( cy.isNA0(i) ) continue;
        if (classification) {
          int cls = (int)cy.at80(i);
          chk_work(chks,cls).set0(i,1L);
        } else {
          float pred = (float) cy.at0(i);
          chk_work(chks,0).set0(i,pred);
        }
      }
    }
  }

  // --------------------------------------------------------------------------
  // Build the next random k-trees representing tid-th tree
  /**
   * Builds the K trees of iteration {@code tid}.
   *
   * <p>Steps: create a root per non-empty class, sample rows (marking the rest
   * OUT_OF_BAG), grow layers up to {@code max_depth}, turn the frontier into leaf
   * nodes, and finally push OOB rows to the leaves to update on-the-fly predictions.
   *
   * @return the K trees (null entries for ignored classes), or {@code null} if the
   *         job was cancelled while growing the trees
   */
  private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
    // We're going to build K (nclass) trees - each focused on correcting
    // errors for a single class.
    final DTree[] ktrees = new DTree[_nclass];

    // Initial set of histograms.  All trees; one leaf per tree (the root
    // leaf); all columns
    DHistogram hcs[][][] = new DHistogram[_nclass][1/*just root leaf*/][_ncols];

    // Adjust nbins for the top-levels
    int adj_nbins = Math.max((1<<(10-0)),nbins);

    // Use for all k-trees the same seed. NOTE: this is only to make a fair
    // view for all k-trees
    long rseed = rand.nextLong();
    // Initially setup as-if an empty-split had just happened
    for( int k=0; k<_nclass; k++ ) {
      assert (_distribution!=null && classification) || (_distribution==null && !classification);
      if( _distribution == null || _distribution[k] != 0 ) { // Ignore missing classes
        // The Boolean Optimization cannot be applied here for RF !
        // This optimization assumes the 2nd tree of a 2-class system is the
        // inverse of the first.  This is false for DRF (and true for GBM) -
        // DRF picks a random different set of columns for the 2nd tree.
        //if( k==1 && _nclass==2 ) continue;
        ktrees[k] = new DRFTree(fr,_ncols,(char)nbins,(char)_nclass,min_rows,mtrys,rseed);
        boolean isBinom = classification;
        new DRFUndecidedNode(ktrees[k],-1, DHistogram.initialHist(fr,_ncols,adj_nbins,hcs[k][0],min_rows,do_grpsplit,isBinom) ); // The "root" node
      }
    }

    // Sample - mark the lines by putting 'OUT_OF_BAG' into nid(<klass>) vector
    Timer t_1 = new Timer();
    Sample ss[] = new Sample[_nclass];
    for( int k=0; k<_nclass; k++)
      if (ktrees[k] != null) ss[k] = new Sample((DRFTree)ktrees[k], sample_rate).dfork(0,new Frame(vec_nids(fr,k),vec_resp(fr,k)), build_tree_one_node);
    for( int k=0; k<_nclass; k++)
      if( ss[k] != null ) ss[k].getResult();
    Log.debug(Sys.DRF__, "Sampling took: " + t_1);

    int[] leafs = new int[_nclass]; // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i

    // ----
    // One Big Loop till the ktrees are of proper depth.
    // Adds a layer to the trees each pass.
    Timer t_2 = new Timer();
    int depth=0;
    for( ; depth<max_depth; depth++ ) {
      if( !Job.isRunning(self()) ) return null;

      hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_one_node);

      // If we did not make any new splits, then the tree is split-to-death
      if( hcs == null ) break;
    }
    Log.debug(Sys.DRF__, "Tree build took: " + t_2);

    // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
    // LeafNodes to hold predictions.
    Timer t_3 = new Timer();
    for( int k=0; k<_nclass; k++ ) {
      DTree tree = ktrees[k];
      if( tree == null ) continue;
      int leaf = leafs[k] = tree.len();
      for( int nid=0; nid<leaf; nid++ ) {
        if( tree.node(nid) instanceof DecidedNode ) {
          DecidedNode dn = tree.decided(nid);
          for( int i=0; i<dn._nids.length; i++ ) {
            int cnid = dn._nids[i];
            if( cnid == -1 || // Bottomed out (predictors or responses known constant)
                tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth
                (tree.node(cnid) instanceof DecidedNode &&  // Or not possible to split
                 ((DecidedNode)tree.node(cnid))._split.col()==-1) ) {
              LeafNode ln = new DRFLeafNode(tree,nid);
              ln._pred = dn.pred(i);  // Set prediction into the leaf
              dn._nids[i] = ln.nid(); // Mark a leaf here
            }
          }
          // Handle the trivial non-splitting tree
          if( nid==0 && dn._split.col() == -1 )
            new DRFLeafNode(tree,-1,0);
        }
      }
    } // -- k-trees are done
    Log.debug(Sys.DRF__, "Nodes propagation: " + t_3);


    // ----
    // Move rows into the final leaf rows
    Timer t_4 = new Timer();
    CollectPreds cp = new CollectPreds(ktrees,leafs).doAll(fr,build_tree_one_node);
    if (importance) {
      if (classification)   asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows); // Track right votes over OOB rows for this tree
      else /* regression */ asSSE  (_treeMeasuresOnOOB).append(cp.sse, cp.allRows);
    }
    Log.debug(Sys.DRF__, "CollectPreds done: " + t_4);

    // Collect leaves stats
    for (int i=0; i<ktrees.length; i++)
      if( ktrees[i] != null )
        ktrees[i].leaves = ktrees[i].len() - leafs[i];
    // DEBUG: Print the generated K trees
    //printGenerateTrees(ktrees);

    return ktrees;
  }

  // Read the 'tree' columns, do model-specific math and put the results in the
  // fs[] array, and return the sum.  Dividing any fs[] element by the sum
  // turns the results into a probability distribution.
  @Override protected float score1( Chunk chks[], float fs[/*nclass*/], int row ) {
    float sum=0;
    for( int k=0; k<_nclass; k++ ) // Sum across of likelihoods
      sum+=(fs[k+1]=(float)chk_tree(chks,k).at0(row));
    if (_nclass == 1) sum /= (float)chk_oobt(chks).at0(row); // for regression average per trees voted for this row (only trees which have row in "out-of-bag"
    return sum;
  }

  /** A row is in-bag while no tree has scored it as out-of-bag (its OOB-tree counter is still 0). */
  @Override protected boolean inBagRow(Chunk[] chks, int row) {
    return chk_oobt(chks).at80(row) == 0;
  }

  // Collect and write predictions into leafs.
  /**
   * Pushes every OOB row down the freshly built K trees, accumulating each tree's
   * prediction into the row's tree column and updating the OOB-tree counter. Node
   * ids of all rows are reset to 0 afterwards. When variable importance is
   * requested, it also counts correct OOB votes (classification) or the OOB sum of
   * squared errors (regression) for these trees.
   */
  private class CollectPreds extends MRTask2<CollectPreds> {
    /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
    /* @OUT */ long rightVotes; // number of right votes over OOB rows (performed by this tree) represented by DTree[] _trees
    /* @OUT */ long allRows;    // number of all OOB rows (sampled by this tree)
    /* @OUT */ float sse;      // Sum of squares for this tree only
    CollectPreds(DTree trees[], int leafs[]) { _trees=trees; }
    @Override public void map( Chunk[] chks ) {
      final Chunk    y       = importance ? chk_resp(chks) : null; // Response
      final float [] rpred   = importance ? new float [1+_nclass] : null; // Row prediction
      final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
      final Chunk   oobt  = chk_oobt(chks); // Out-of-bag rows counter over all trees
      // Iterate over all rows
      for( int row=0; row<oobt._len; row++ ) {
        boolean wasOOBRow = false;   // Was this row OOB for the trees visited so far?
        boolean visitedTree = false; // Has an earlier (non-skipped) tree examined this row?
        // For all tree (i.e., k-classes)
        for( int k=0; k<_nclass; k++ ) {
          final DTree tree = _trees[k];
          if( tree == null ) continue; // Empty class is ignored
          // If we have all constant responses, then we do not split even the
          // root and the residuals should be zero.
          if( tree.root() instanceof LeafNode ) continue;
          final Chunk nids = chk_nids(chks,k); // Node-ids  for this tree/class
          final Chunk ct   = chk_tree(chks,k); // k-tree working column holding votes for given row
          int nid = (int)nids.at80(row);         // Get Node to decide from
          // Update only out-of-bag rows
          // This is out-of-bag row - but we would like to track on-the-fly prediction for the row
          if( isOOBRow(nid) ) { // The row should be OOB for all k-trees !!!
            // The k-trees are built from the same seed (see buildNextKTrees), so every
            // visited tree must agree on OOB rows. Compare against the previously
            // *visited* tree: tree 0 may have been skipped (missing class or constant tree).
            assert !visitedTree || wasOOBRow : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
            wasOOBRow = true;
            nid = oob2Nid(nid);
            if( tree.node(nid) instanceof UndecidedNode ) // If we bottomed out the tree
              nid = tree.node(nid).pid();                 // Then take parent's decision
            DecidedNode dn = tree.decided(nid);           // Must have a decision point
            if( dn._split.col() == -1 )     // Unable to decide?
              dn = tree.decided(tree.node(nid).pid());    // Then take parent's decision
            int leafnid = dn.ns(chks,row); // Decide down to a leafnode
            // Setup Tree(i) - on the fly prediction of i-tree for row-th row
            //   - for classification: cumulative number of votes for this row
            //   - for regression: cumulative sum of prediction of each tree - has to be normalized by number of trees
            double prediction = ((LeafNode)tree.node(leafnid)).pred(); // Prediction for this k-class and this row
            if (importance) rpred[1+k] = (float) prediction; // for both regression and classification
            ct.set0(row, (float)(ct.at0(row) +  prediction));
            // For this tree this row is out-of-bag - i.e., a tree voted for this row
            oobt.set0(row, _nclass>1?1:oobt.at0(row)+1); // for regression track number of trees, for classification boolean flag is enough
          }
          visitedTree = true;
          // reset help column for this row and this k-class
          nids.set0(row,0);
        } /* end of k-trees iteration */
        if (importance) {
          if (wasOOBRow && !y.isNA0(row)) {
            if (classification) {
              int treePred = ModelUtils.getPrediction(rpred, data_row(chks,row, rowdata));
              int actuPred = (int) y.at80(row);
              if (treePred==actuPred) rightVotes++; // No miss !
            } else { // regression
              float  treePred = rpred[1];
              float  actuPred = (float) y.at0(row);
              sse += (actuPred-treePred)*(actuPred-treePred);
            }
            allRows++;
          }
        }
      }
    }
    @Override public void reduce(CollectPreds mrt) {
      rightVotes += mrt.rightVotes;
      allRows    += mrt.allRows;
      sse        += mrt.sse;
    }
  }

  // A standard DTree with a few more bits.  Support for sampling during
  // training, and replaying the sample later on the identical dataset to
  // e.g. compute OOBEE.
  static class DRFTree extends DTree {
    final int _mtrys;           // Number of columns to choose amongst in splits
    final long _seeds[];        // One seed for each chunk, for sampling
    final transient Random _rand; // RNG for split decisions & sampling
    DRFTree( Frame fr, int ncols, char nbins, char nclass, int min_rows, int mtrys, long seed ) {
      super(fr._names, ncols, nbins, nclass, min_rows, seed);
      _mtrys = mtrys;
      _rand = createRNG(seed);
      _seeds = new long[fr.vecs()[0].nChunks()];
      for( int i=0; i<_seeds.length; i++ )
        _seeds[i] = _rand.nextLong();
    }
    // Return a deterministic chunk-local RNG.  Can be kinda expensive.
    @Override public Random rngForChunk( int cidx ) {
      long seed = _seeds[cidx];
      return createRNG(seed);
    }
  }

  @Override protected DecidedNode makeDecided( UndecidedNode udn, DHistogram hs[] ) {
    return new DRFDecidedNode(udn,hs);
  }

  // DRF DTree decision node: same as the normal DecidedNode, but specifies a
  // decision algorithm given complete histograms on all columns.
  // DRF algo: find the lowest error amongst a random mtry columns.
  static class DRFDecidedNode extends DecidedNode {
    DRFDecidedNode( UndecidedNode n, DHistogram hs[] ) { super(n,hs); }
    @Override public DRFUndecidedNode makeUndecidedNode( DHistogram hs[] ) {
      return new DRFUndecidedNode(_tree,_nid, hs);
    }

    // Find the column with the best split (lowest score).
    // Returns a sentinel split with col == -1 when no column yields a split.
    @Override public DTree.Split bestCol( UndecidedNode u, DHistogram hs[] ) {
      DTree.Split best = new DTree.Split(-1,-1,null,(byte)0,Double.MAX_VALUE,Double.MAX_VALUE,0L,0L,0,0);
      if( hs == null ) return best;
      for( int i=0; i<u._scoreCols.length; i++ ) {
        int col = u._scoreCols[i];
        DTree.Split s = hs[col].scoreMSE(col);
        if( s == null ) continue;
        if( s.se() < best.se() ) best = s;
        if( s.se() <= 0 ) break; // No point in looking further!
      }
      return best;
    }
  }

  // DRF DTree undecided node: same as the normal UndecidedNode, but specifies
  // a list of columns to score on now, and then decide over later.
  // DRF algo: pick a random mtry columns
  static class DRFUndecidedNode extends UndecidedNode {
    DRFUndecidedNode( DTree tree, int pid, DHistogram[] hs ) { super(tree,pid, hs); }

    // Randomly select mtry columns to 'score' in following pass over the data.
    // Uses a partial Fisher-Yates shuffle: chosen columns are swapped to the tail
    // of 'cols', and the tail [len, choices) is returned.
    @Override public int[] scoreCols( DHistogram[] hs ) {
      DRFTree tree = (DRFTree)_tree;
      int[] cols = new int[hs.length];
      int len=0;
      // Gather all active columns to choose from.
      for( int i=0; i<hs.length; i++ ) {
        if( hs[i]==null ) continue; // Ignore not-tracked cols
        assert hs[i]._min < hs[i]._maxEx && hs[i].nbins() > 1 : "broken histo range "+hs[i];
        cols[len++] = i;        // Gather active column
      }
      int choices = len;        // Number of columns I can choose from
      assert choices > 0;

      // Draw up to mtry columns at random without replacement.
      for( int i=0; i<tree._mtrys; i++ ) {
        if( len == 0 ) break;   // Out of choices!
        int idx2 = tree._rand.nextInt(len);
        int col = cols[idx2];     // The chosen column
        cols[idx2] = cols[--len]; // Compress out of array; do not choose again
        cols[len] = col;          // Swap chosen in just after 'len'
      }
      assert choices - len > 0;
      return Arrays.copyOfRange(cols,len,choices);
    }
  }

  static class DRFLeafNode extends LeafNode {
    DRFLeafNode( DTree tree, int pid ) { super(tree,pid); }
    DRFLeafNode( DTree tree, int pid, int nid ) { super(tree,pid,nid); }
    // Serialize just the leaf prediction, as a single 4-byte float.
    @Override protected AutoBuffer compress(AutoBuffer ab) { assert !Double.isNaN(pred()); return ab.put4f((float)pred()); }
    @Override protected int size() { return 4; }
  }

  // Deterministic sampling: using the tree's per-chunk RNG, each row is kept with
  // probability 'rate'; rows not kept, or with an NaN response, are flagged OUT_OF_BAG.
  static class Sample extends MRTask2<Sample> {
    final DRFTree _tree;
    final float _rate;
    Sample( DRFTree tree, float rate ) { _tree = tree; _rate = rate; }
    @Override public void map( Chunk nids, Chunk ys ) {
      Random rand = _tree.rngForChunk(nids.cidx());
      for( int row=0; row<nids._len; row++ )
        if( rand.nextFloat() >= _rate || Double.isNaN(ys.at0(row)) ) {
          nids.set0(row, OUT_OF_BAG);     // Flag row as being ignored by sampling
        }
    }
  }

  /**
   * Cross-Validate a DRF model by building new models on N train/test holdout splits
   * @param splits Frames containing train/test splits
   * @param cv_preds Array of Frames to store the predictions for each cross-validation run
   * @param offsets Array to store the offsets of starting row indices for each cross-validation run
   * @param i Which fold of cross-validation to perform
   */
  @Override public void crossValidate(Frame[] splits, Frame[] cv_preds, long[] offsets, int i) {
    // Train a clone with slightly modified parameters (to account for cross-validation)
    DRF cv = (DRF) this.clone();
//    cv.importance = false; //Don't compute variable importance for N CV-folds
    cv.genericCrossValidation(splits, offsets, i);
    cv_preds[i] = ((DRFModel) UKV.get(cv.dest())).score(cv.validation); // cv_preds is escaping the context of this function and needs to be DELETED by the caller!!!
  }
}
